# -*- coding:utf-8 -*-
import math
import os
import shutil
from collections import ChainMap, deque

import numpy as np
import pandas as pd
from fastdtw import fastdtw
from gensim.models import Word2Vec
from joblib import Parallel, delayed

from utils import partition_dict, preprocess_nxgraph
from alias import create_alias_table
from walker import BiasedWalker

class Struc2Vec():
    def __init__(self, nx_graph, walk_length=10, num_walks=100, workers=1, verbose=0, stay_prob=0.3, opt1_reduce_len=True, opt2_reduce_sim_calc=True, opt3_num_layers=None, temp_path='./temp_struc2vec/', reuse=False):
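        """struc2vec embedding model over a networkx graph.

        walk_length/num_walks control the biased random walks, stay_prob is the
        probability of staying in the current layer during a walk, and the
        opt1/opt2/opt3 flags toggle the optimizations from the paper (see the
        comments below). Intermediate results are cached under temp_path; pass
        reuse=True to reuse them across runs.
        """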
        self.graph = nx_graph
        self.idx2node, self.node2idx = preprocess_nxgraph(self.graph)
        self.idx = list(range(len(self.idx2node)))

        # Optimization 1: shorten the degree sequences by storing (degree, count of nodes with that degree) pairs
        self.opt1_reduce_len = opt1_reduce_len
        # Optimization 2: reduce the number of node pairs whose similarity is computed;
        # the paper limits each node to roughly log(n) candidate pairs per layer
        self.opt2_reduce_sim_calc = opt2_reduce_sim_calc
        # Optimization 3 from the paper: cap the number of layers
        self.opt3_num_layers = opt3_num_layers

        self.reuse = reuse
        self.temp_path = temp_path

        if not os.path.exists(self.temp_path):
            os.mkdir(self.temp_path)
        if not reuse:
            shutil.rmtree(self.temp_path)
            os.mkdir(self.temp_path)
        
        self.create_context_graph(self.opt3_num_layers, workers, verbose)
        self.prepare_biased_walk()
        # TODO: the sentence generation still needs to be completed
        self.walker = BiasedWalker(self.idx2node, self.temp_path)
        self.sentences = self.walker.simulate_walks(num_walks, walk_length, stay_prob, workers, verbose)

        self.w2v_model = None
        self._embeddings = {}


    def create_context_graph(self, max_num_layers, workers=1, verbose=0):
        # compute the pairwise structural distances between nodes
        pair_distances = self._compute_structural_distance(max_num_layers, workers, verbose)
        # re-index the pairwise distances by layer
        layers_adj, layers_distances = self._get_layer_rep(pair_distances)
        pd.to_pickle(layers_adj, self.temp_path + "layers_adj.pkl")
        # build the per-node intra-layer transition probabilities (alias tables)
        layers_accept, layer_alias = self._get_transition_probs(layers_adj, layers_distances)
        pd.to_pickle(layer_alias, self.temp_path + "layers_alias.pkl")
        pd.to_pickle(layers_accept, self.temp_path + "layers_accept.pkl")

    def _compute_ordered_degreelist(self, max_num_layers):
        # build, for every node, the multi-layer ordered degree sequences used for distance computation
        degree_list = {}
        vertices = self.idx
        for v in vertices:
            # build a BFS "tree" rooted at each node v:
            # degree_list =
            # v : {
            #   level_1: [
            #       (degree_1, number of nodes with that degree),
            #       (degree_2, number of nodes with that degree)
            #   ],
            #   level_2: [....]}
            degree_list[v] = self._get_order_degreelist_node(v, max_num_layers)
        return degree_list

    def _get_order_degreelist_node(self, root, max_num_layers=None):
        # Compute the ordered degree sequence of every BFS level around `root`
        # using a level-order traversal.
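        # Example (with opt1_reduce_len): for the center of a 3-leaf star graph this
        # returns {0: [(3, 1)], 1: [(1, 3)]} -- level 0 is the root itself (degree 3),
        # level 1 holds the three leaves (degree 1 each).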
        if max_num_layers is None:
            max_num_layers = float('inf')
        
        ordered_degree_sequence_dict = {}
        visited = [False] * len(self.graph.nodes())
        queue = deque()
        level = 0
        queue.append(root)
        visited[root] = True

        # level-order (BFS) traversal
        while (len(queue) > 0 and level <= max_num_layers):
            count = len(queue) # number of nodes in the current layer
            if self.opt1_reduce_len:
                degree_list = {}
            else:
                degree_list = []

            while (count > 0):
                top = queue.popleft()
                # degree = number of neighbors
                node = self.idx2node[top]
                degree = len(self.graph[node])
                if self.opt1_reduce_len:
                    # count nodes in this layer that share the same degree
                    degree_list[degree] = degree_list.get(degree, 0) + 1
                else:
                    degree_list.append(degree)
                
                # push unvisited neighbors onto the queue
                for i in self.graph[node]:
                    nei_idx = self.node2idx[i]
                    if not visited[nei_idx]:
                        visited[nei_idx] = True
                        queue.append(nei_idx)
                count -= 1

            if self.opt1_reduce_len:
                # sort by degree
                order_degree_list = [(degree, freq) for degree, freq in degree_list.items()]
                order_degree_list.sort(key=lambda x: x[0])
            else:
                order_degree_list = sorted(degree_list)
            # store the ordered degree sequence keyed by level
            ordered_degree_sequence_dict[level] = order_degree_list
            level += 1
        
        return ordered_degree_sequence_dict

    def _compute_structural_distance(self, max_num_layers, workers=1, verbose=0):
        # return: 
        #     dtw_dict = {
        #         (v1, v2): {
        #             layer1: dist1, 
        #             layer2: dist2,
        #             ...
        #         },
        #         (v3, v4): {
        #             layer1: dist1, 
        #             layer2: dist2,
        #             ...
        #         }, ...
        #     }
        if os.path.exists(self.temp_path + 'structural_dist.pkl'):
            structural_dist = pd.read_pickle(self.temp_path+'structural_dist.pkl')
        else:
            if self.opt1_reduce_len:
                dist_func = cost_max
            else:
                dist_func = cost

            if os.path.exists(self.temp_path + 'degreelist.pkl'):
                degree_list = pd.read_pickle(self.temp_path + 'degreelist.pkl')
            else:
                # build, for each node, the per-hop degree info up to max_num_layers (keys: hop i, values: [(degree, node count)])
                degree_list = self._compute_ordered_degreelist(max_num_layers)
                pd.to_pickle(degree_list, self.temp_path + 'degreelist.pkl')
            
            if self.opt2_reduce_sim_calc:
                # build a dict keyed by degree whose values hold 'vertices', 'before' and 'after'
                degrees = self._create_vectors()
                degree_list_selected = {}
                vertices = {}
                n_nodes = len(self.graph.nodes())
                for v in self.idx:
                    # select candidate nodes with the most similar degree (structural similarity)
                    node = self.idx2node[v]
                    nbs = get_vertices(v, len(self.graph[node]), degrees, n_nodes)
                    vertices[v] = nbs
                    # keep the per-layer ordered degree info only for v and its selected candidates
                    degree_list_selected[v] = degree_list[v]
                    for n in nbs:
                        degree_list_selected[n] = degree_list[n]
            else:
                vertices = {}
                for v in degree_list:
                    vertices[v] = [vd for vd in degree_list.keys() if vd > v]

            results = Parallel(n_jobs=workers, verbose=verbose,)(delayed(compute_dtw_dist)(part_list, degree_list, dist_func) for part_list in partition_dict(vertices, workers))
            dtw_dist = dict(ChainMap(*results))
            structural_dist = convert_dtw_struc_dist(dtw_dist)
            pd.to_pickle(structural_dist, self.temp_path + 'structural_dist.pkl')
        return structural_dist

    def _create_vectors(self):
        # Group nodes by degree; for each degree record its nodes plus the next
        # smaller ('before') and next larger ('after') degree values.
        # degrees = {
        #   degree_value1: {
        #       'vertices': [v1, v2, ...], 
        #       'before': [v1, v2, ...], 
        #       'after': [v1, v2, ...]} 
        #   },
        #   degree_value2: {
        #       'vertices': [v1, v2, ...], 
        #       'before': [v1, v2, ...], 
        #       'after': [v1, v2, ...]} 
        #   }, 
        #   ....
        # }
        degrees = {} # key: a degree value occurring in the graph, value: the nodes with that degree
        degrees_sorted = set() # all distinct degree values in the graph
        for v in self.idx:
            degree = len(self.graph[self.idx2node[v]])
            degrees_sorted.add(degree)
            if (degree not in degrees):
                degrees[degree] = {}
                degrees[degree]['vertices'] = []
            degrees[degree]['vertices'].append(v)
        degrees_sorted = np.array(list(degrees_sorted), dtype='int')
        degrees_sorted = np.sort(degrees_sorted)

        l = len(degrees_sorted)
        for index, degree in enumerate(degrees_sorted):
            if (index > 0):
                # the next smaller degree value
                degrees[degree]['before'] = degrees_sorted[index - 1]
            if (index < (l - 1)):
                # the next larger degree value
                degrees[degree]['after'] = degrees_sorted[index + 1]
        return degrees

    def _get_layer_rep(self, pair_distances):
        # Re-index pair_distances: first by layer, then by node pair.
        layer_distances = {}
        layer_adj = {}
        for v_pair, layer_dist in pair_distances.items():
            for layer, distance in layer_dist.items():
                vx = v_pair[0]
                vy = v_pair[1]

                layer_distances.setdefault(layer, {})
                layer_distances[layer][vx, vy] = distance

                layer_adj.setdefault(layer, {})
                layer_adj[layer].setdefault(vx, [])
                layer_adj[layer].setdefault(vy, [])
                # adjacency list of nodes that are neighbors within this layer
                layer_adj[layer][vx].append(vy)
                layer_adj[layer][vy].append(vx)
        return layer_adj, layer_distances

    def _get_transition_probs(self, layers_adj, layers_distances):
        layers_alias = {}
        layers_accept = {}

        for layer in layers_adj:
            layer_neighbors = layers_adj[layer]
            layer_distances = layers_distances[layer]
            node_alias_dict = {}
            node_accept_dict = {}
            norm_weights = {}

            for v, neighbors in layer_neighbors.items():
                e_list = []
                sum_w = 0.0

                for n in neighbors:
                    if (v, n) in layer_distances:
                        wd = layer_distances[v, n]
                    else:
                        wd = layer_distances[n, v]
                    w = np.exp(-float(wd))
                    e_list.append(w)
                    sum_w += w
                # normalize the weights over each node's neighbors
                e_list = [x / sum_w for x in e_list]
                norm_weights[v] = e_list
                accept, alias = create_alias_table(e_list)
                node_alias_dict[v] = alias
                node_accept_dict[v] = accept

            pd.to_pickle(norm_weights, self.temp_path + "norm_weights_distance-layer-" + str(layer)+'.pkl')

            layers_alias[layer] = node_alias_dict
            layers_accept[layer] = node_accept_dict
        return layers_accept, layers_alias

    def prepare_biased_walk(self,):
        # Section 3.2 of the paper: build the graph context.
        # For each layer compute the average edge weight and, per node, the Gamma count
        # used later for the cross-layer transition probabilities of the random walks.
        sum_weights = {}    # sum of edge weights per layer
        sum_edges = {}      # number of edges per layer
        average_weight = {} # average edge weight per layer
        gamma = {}
        layer = 0
        while (os.path.exists(self.temp_path + 'norm_weights_distance-layer-' + str(layer) + '.pkl')):
            # load each node's normalized weights towards its neighbors in this layer
            probs = pd.read_pickle(self.temp_path + 'norm_weights_distance-layer-' + str(layer) + '.pkl')
            for v, list_weights in probs.items():
                sum_weights.setdefault(layer, 0)
                sum_edges.setdefault(layer, 0)
                sum_weights[layer] += sum(list_weights)
                sum_edges[layer] += len(list_weights)

            average_weight[layer] = sum_weights[layer] / sum_edges[layer]
            # Gamma: for each node v, count neighbors whose edge weight exceeds the layer's average weight
            gamma.setdefault(layer, {})
            for v, list_weights in probs.items():
                num_neighbours = 0
                # count weights above the layer average (matches the optimization in the paper)
                for w in list_weights:
                    if (w > average_weight[layer]):
                        num_neighbours += 1
                gamma[layer][v] = num_neighbours
            layer += 1
        pd.to_pickle(average_weight, self.temp_path + 'average_weight')
        pd.to_pickle(gamma, self.temp_path + 'gamma.pkl')

    def train(self, embed_size=128, window_size=5, workers=3, iter=5):
    
        # pd.read_pickle(self.temp_path+'walks.pkl')
        sentences = self.sentences

        print("Learning representation...")
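        # NOTE: `size` and `iter` are the gensim 3.x parameter names;
        # gensim 4.x renamed them to `vector_size` and `epochs`.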
        model = Word2Vec(sentences, size=embed_size, window=window_size, min_count=0, hs=1, sg=1, workers=workers,
                         iter=iter)
        print("Learning representation done!")
        self.w2v_model = model

        return model

    def get_embeddings(self,):
        if self.w2v_model is None:
            print("model not trained")
            return {}

        self._embeddings = {}
        for word in self.graph.nodes():
            self._embeddings[word] = self.w2v_model.wv[word]

        return self._embeddings

def cost(a, b):
    ep = 0.5
    m = max(a, b) + ep
    mi = min(a, b) + ep
    return ((m / mi) - 1)

def cost_min(a, b):
    ep = 0.5
    m = max(a[0], b[0]) + ep
    mi = min(a[0], b[0]) + ep
    return ((m / mi) - 1) * min(a[1], b[1])

def cost_max(a, b):
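    # Worked example: cost_max((2, 3), (4, 5))
    #   = ((4 + 0.5) / (2 + 0.5) - 1) * max(3, 5) = 0.8 * 5 = 4.0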
    ep = 0.5
    m = max(a[0], b[0]) + ep
    mi = min(a[0], b[0]) + ep
    return ((m / mi) - 1) * max(a[1], b[1])

def convert_dtw_struc_dist(distances, start_layer=1):
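    # Example: {(1, 2): {0: 1.0, 1: 0.5, 2: 0.3}} with start_layer=1 becomes
    # {(1, 2): {0: 1.0, 1: 1.5, 2: 1.8}}: each remaining layer accumulates the previous one.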
    for vertices, layers in distances.items():
        keys_layers = sorted(layers.keys())
        start_layer = min(len(keys_layers), start_layer)
        for layer in range(0, start_layer):
            # drop the first `start_layer` levels
            keys_layers.pop(0)

        # accumulate layer by layer so that deeper layers include all shallower distances
        for layer in keys_layers:
            layers[layer] += layers[layer - 1]
    return distances

def get_vertices(v, degree_v, degrees, n_nodes):
    # Select roughly 2*log2(n) nodes whose degrees are closest to degree_v.
    a_vertices_selected = 2 * math.log(n_nodes, 2) # cap on the number of selected nodes
    vertices = []
    try:
        c_v = 0
        for v2 in degrees[degree_v]['vertices']:
            if (v != v2):
                vertices.append(v2)
                c_v += 1
                if (c_v > a_vertices_selected):
                    raise StopIteration

        # if there is no 'before', degree_v is the smallest degree
        if ('before' not in degrees[degree_v]):
            degree_b = -1
        else:
            degree_b = degrees[degree_v]['before']
        
        # if there is no 'after', degree_v is the largest degree
        if ('after' not in degrees[degree_v]):
            degree_a = -1
        else:
            degree_a = degrees[degree_v]['after']

        # the graph contains only one distinct degree
        if (degree_b == -1 and degree_a == -1):
            raise StopIteration
        degree_now = verify_degrees(degrees, degree_v, degree_a, degree_b)

        while True:
            for v2 in degrees[degree_now]['vertices']:
                if (v != v2):
                    vertices.append(v2)
                    c_v += 1
                    if (c_v > a_vertices_selected):
                        raise StopIteration

            if (degree_now == degree_b):
                if ('before' not in degrees[degree_b]):
                    degree_b = -1
                else:
                    degree_b = degrees[degree_b]['before']
            else:
                if ('after' not in degrees[degree_a]):
                    degree_a = -1
                else:
                    degree_a = degrees[degree_a]['after']
            # this check is still needed: it also fires once candidate degrees in both
            # directions have been exhausted (degree_b == degree_a == -1)
            if (degree_b == -1 and degree_a == -1):
                raise StopIteration
            degree_now = verify_degrees(degrees, degree_v, degree_a, degree_b)
    except StopIteration:
        return list(vertices)
    return list(vertices)

def verify_degrees(degrees, degree_v_root, degree_a, degree_b):
    # Pick whichever candidate degree is closer in absolute value to the root node's degree.
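    # Example: degree_v_root=3, degree_b=2, degree_a=5 -> returns 2, since |2 - 3| < |5 - 3|.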
    if (degree_b == -1):
        degree_now = degree_a
    elif (degree_a == -1):
        degree_now = degree_b
    # choose the degree closer to the root's degree
    elif (abs(degree_b - degree_v_root) < abs(degree_a - degree_v_root)):
        degree_now = degree_b
    else:
        degree_now = degree_a
    return degree_now

def compute_dtw_dist(part_list, degree_list, dist_func):
    # part_list: iterable of (v, nbs) pairs
    # degree_list =
    # v : {
    #   level_1: [
    #       (degree_1, number of nodes with that degree),
    #       (degree_2, number of nodes with that degree)
    #   ],
    #   level_2: [....]}
    # return: 
    # dtw_dict = {
    #   (v1, v2): {
    #       layer1: dist1, 
    #       layer2: dist2,
    #       ...
    #   },
    #   (v3, v4): {
    #       layer1: dist1, 
    #       layer2: dist2,
    #       ...
    #   }, ...
    # }
    dtw_dict = {}
    # For each v1, compute DTW distances between its ordered degree lists and those of each candidate v2.
    for v1, nbs in part_list:
        lists_v1 = degree_list[v1] # ordered degree list of v1
        for v2 in nbs:
            lists_v2 = degree_list[v2] # ordered degree list of v2
            max_layer = min(len(lists_v1), len(lists_v2))
            dtw_dict[v1, v2] = {}
            for layer in range(0, max_layer):
                # DTW distance between v1's and v2's degree sequences at this layer
                dist, path = fastdtw(lists_v1[layer], lists_v2[layer], radius=1, dist=dist_func)
                dtw_dict[v1, v2][layer] = dist
    return dtw_dict
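

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Assumes networkx is installed and that the companion modules imported above
# (utils, alias, walker) are importable. Node labels are cast to str so the walk
# sentences are plain string tokens for Word2Vec.
if __name__ == "__main__":
    import networkx as nx

    # small demo graph; any networkx graph works
    G = nx.relabel_nodes(nx.karate_club_graph(), str)

    model = Struc2Vec(G, walk_length=10, num_walks=80, workers=1, verbose=0)
    model.train(embed_size=64, window_size=5, workers=1, iter=3)
    embeddings = model.get_embeddings()
    print("learned %d embeddings of dimension %d"
          % (len(embeddings), len(next(iter(embeddings.values())))))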


